import os
import torch
import torch.nn as nn
import numpy as np
from utils import get_metrics4intent, get_metrics4slot
from seqeval.metrics.sequence_labeling import get_entities


# 训练器
# Trainer: wraps a joint domain / intent / slot-filling model with
# training, evaluation and single-utterance prediction loops.
class Trainer:
    def __init__(self, model, config) -> None:
        """Store model/config and build the loss and optimizer.

        Args:
            model: module whose forward(input_ids, attention_mask=...,
                token_type_ids=...) returns a tuple
                (domain_logits, intent_logits, slot_logits); slot_logits is
                token-level with shape (batch, seq_len, num_slot_labels).
            config: settings object exposing lr, epoch, device, max_len and
                the id2domainlabel / id2intentlabel / id2slotlabel mappings.
        """
        self.model = model                      # joint classification model
        self.config = config                    # hyper-parameters / label maps
        self.criterion = nn.CrossEntropyLoss()  # shared CE loss for all three heads
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr)
        self.epoch = self.config.epoch          # number of training epochs
        self.device = self.config.device        # target device for tensors

    def set_input2device(self, input):
        """Move every tensor in the batch dict to the configured device (in place)."""
        for key in input.keys():
            input[key] = input[key].to(self.device)
        return input

    # Persist model weights.
    def save_model(self, save_path, save_name):
        """Save the model's state_dict to save_path/save_name, creating the directory if needed."""
        os.makedirs(save_path, exist_ok=True)  # avoid failure on a missing directory
        torch.save(self.model.state_dict(), os.path.join(save_path, save_name))

    # Training loop.
    def train(self, train_data_loader):
        """Train for self.epoch epochs; prints the average summed loss per epoch.

        Each batch dict must contain input_ids, attention_mask, token_type_ids,
        domain_label_ids, intent_label_ids and slot_label_ids.
        """
        global_step = 0
        total_step = len(train_data_loader) * self.epoch  # total optimization steps
        self.model.train()
        for epoch in range(self.epoch):
            total_loss = 0.0
            for step, batch in enumerate(train_data_loader):
                batch = self.set_input2device(batch)
                # Forward pass through the three heads.
                domain_output, intent_output, slot_output = self.model(
                    batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids'])
                # Compute the slot loss only on real (non-padding) positions.
                active_mask = batch['attention_mask'].view(-1) == 1  # (batch*seq_len,)
                active_logits = slot_output.view(-1, slot_output.shape[2])[active_mask]
                active_labels = batch['slot_label_ids'].view(-1)[active_mask]

                domain_loss = self.criterion(domain_output, batch['domain_label_ids'])
                intent_loss = self.criterion(intent_output, batch['intent_label_ids'])
                slot_loss = self.criterion(active_logits, active_labels)
                # Unweighted sum of the three task losses.
                loss = domain_loss + intent_loss + slot_loss
                total_loss += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # print(f'[train] epoch:{epoch+1} {global_step}/{total_step} loss:{loss.item()}')
                global_step += 1

            avg_loss = total_loss / len(train_data_loader)
            print(f"Epoch {epoch+1}/{self.epoch}, Loss: {avg_loss}")

    # Evaluation: prints accuracy / precision / recall / f1 for all three tasks.
    def eval(self, eval_data_loader):
        """Evaluate the model on eval_data_loader and print per-task metrics."""
        self.model.eval()
        domain_preds, domain_golds = [], []
        intent_preds, intent_golds = [], []
        slot_preds, slot_golds = [], []
        with torch.no_grad():
            for step, batch in enumerate(eval_data_loader):
                batch = self.set_input2device(batch)
                domain_output, intent_output, slot_output = self.model(
                    batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids'])
                # Domain head: argmax over logits vs. gold ids.
                domain_output = domain_output.detach().cpu().numpy()
                domain_output = np.argmax(domain_output, -1)
                domain_label_ids = batch['domain_label_ids'].detach().cpu().numpy()
                domain_preds.extend(domain_output)
                domain_golds.extend(domain_label_ids.reshape(-1))
                # Intent head: same treatment.
                intent_output = intent_output.detach().cpu().numpy()
                intent_output = np.argmax(intent_output, -1)
                intent_label_ids = batch['intent_label_ids'].detach().cpu().numpy()
                intent_preds.extend(intent_output)
                intent_golds.extend(intent_label_ids.reshape(-1))
                # Slot head: decode per-token label ids back to label strings.
                slot_output = slot_output.detach().cpu().numpy()
                slot_label_ids = batch['slot_label_ids'].detach().cpu().numpy()
                slot_output = np.argmax(slot_output, -1)
                # Number of real tokens per sequence (incl. [CLS]/[SEP]); keep on CPU
                # so the values can safely index numpy arrays below.
                active_len = batch['attention_mask'].sum(dim=-1).detach().cpu().tolist()
                for length, t_output, t_label in zip(active_len, slot_output, slot_label_ids):
                    # Strip [CLS] (pos 0) and [SEP] (pos length-1) before mapping to labels.
                    t_output = [self.config.id2slotlabel[i] for i in t_output[1:length - 1]]
                    t_label = [self.config.id2slotlabel[i] for i in t_label[1:length - 1]]
                    slot_preds.append(t_output)
                    slot_golds.append(t_label)
        domain_acc, domain_precision, domain_recall, domain_f1 = get_metrics4intent(domain_golds, domain_preds)
        print(f'领域识别：\naccuracy:{domain_acc}\nprecision:{domain_precision}\nrecall:{domain_recall}\nf1:{domain_f1}')
        intent_acc, intent_precision, intent_recall, intent_f1 = get_metrics4intent(intent_golds, intent_preds)
        print(f'意图识别：\naccuracy:{intent_acc}\nprecision:{intent_precision}\nrecall:{intent_recall}\nf1:{intent_f1}')
        slot_acc, slot_precision, slot_recall, slot_f1 = get_metrics4slot(slot_golds, slot_preds)
        print(f'槽填充：\naccuracy:{slot_acc}\nprecision:{slot_precision}\nrecall:{slot_recall}\nf1:{slot_f1}')

    # Single-utterance prediction.
    def predict(self, text, tokenizer):
        """Predict domain, intent and slots for one utterance.

        Args:
            text: the raw utterance string (tokenized character by character).
            tokenizer: a HuggingFace-style tokenizer providing encode_plus.

        Returns:
            dict with keys 'domain', 'intent' and 'slots', where 'slots' is a
            list of (label, surface_text, start, end) tuples.
        """
        self.model.eval()
        with torch.no_grad():
            tmp_text = [i for i in text]  # character-level tokens
            inputs = tokenizer.encode_plus(
                text=tmp_text,
                max_length=self.config.max_len,
                padding='max_length',
                truncation='only_first',
                return_attention_mask=True,
                return_token_type_ids=True,
                return_tensors='pt'
            )
            inputs = self.set_input2device(inputs)

            domain_output, intent_output, slot_output = self.model(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                token_type_ids=inputs['token_type_ids'])
            domain_output = domain_output.detach().cpu().numpy()
            domain_output = np.argmax(domain_output, -1)[0]
            print('领域：', self.config.id2domainlabel[domain_output])
            intent_output = intent_output.detach().cpu().numpy()
            intent_output = np.argmax(intent_output, -1)[0]
            print('意图：', self.config.id2intentlabel[intent_output])
            slot_output = slot_output.detach().cpu().numpy()
            slot_output = np.argmax(slot_output, -1)
            # Valid length from the attention mask includes [CLS] and [SEP];
            # slicing [1:valid_len-1] keeps exactly the character positions.
            # (The old [1:len(text)-1] dropped the last two characters' slots
            # and ignored truncation at max_len.)
            valid_len = int(inputs['attention_mask'].sum().item())
            slot_output = slot_output[0][1:valid_len - 1]
            slot_output = [self.config.id2slotlabel[i] for i in slot_output]
            print('槽位：', str([(i[0], text[i[1]:i[2]+1], i[1], i[2]) for i in get_entities(slot_output)]))

            return {
                'domain': self.config.id2domainlabel[domain_output],
                'intent': self.config.id2intentlabel[intent_output],
                'slots': [(i[0], text[i[1]:i[2]+1], i[1], i[2]) for i in get_entities(slot_output)]
            }
